/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *
 */

#include <folly/futures/Future.h>
#include <folly/io/async/test/MockAsyncTransport.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

#include <quic/api/QuicStreamAsyncTransport.h>
#include <quic/api/test/Mocks.h>
#include <quic/client/QuicClientTransport.h>
#include <quic/common/test/TestUtils.h>
#include <quic/fizz/client/handshake/FizzClientHandshake.h>
#include <quic/fizz/client/handshake/FizzClientQuicHandshakeContext.h>
#include <quic/server/QuicServer.h>
#include <quic/server/QuicServerTransport.h>
#include <quic/server/test/Mocks.h>

using namespace testing;

namespace quic::test {

/**
 * End-to-end fixture for QuicStreamAsyncTransport.
 *
 * Starts a real QuicServer on the IPv6 loopback address and a
 * QuicClientTransport driven by a dedicated client EventBase thread, then
 * wraps one bidirectional stream on each side in a QuicStreamAsyncTransport.
 * The server side echoes the first read it receives, prefixed with "echo ",
 * and then shuts down its write side; the client side publishes the first
 * read it receives through clientReadPromise_.
 */
class QuicStreamAsyncTransportTest : public Test {
 public:
  void SetUp() override {
    folly::ssl::init();
    createServer();
    createClient();
  }

  /// Starts the QUIC server and installs the server-side expectations that
  /// implement the echo behaviour.
  void createServer() {
    // The mock factory only intercepts transport creation: it still builds a
    // real QuicServerTransport, but records it in serverSocket_ so the stream
    // callback below can wrap one of its streams.
    auto serverTransportFactory =
        std::make_unique<MockQuicServerTransportFactory>();
    EXPECT_CALL(*serverTransportFactory, _make(_, _, _, _))
        .WillOnce(Invoke(
            [&](folly::EventBase* evb,
                std::unique_ptr<folly::AsyncUDPSocket>& socket,
                const folly::SocketAddress& /*addr*/,
                std::shared_ptr<const fizz::server::FizzServerContext> ctx) {
              auto transport = quic::QuicServerTransport::make(
                  evb, std::move(socket), serverConnectionCB_, std::move(ctx));
              // Exactly one server connection is expected per test.
              CHECK(serverSocket_.get() == nullptr);
              serverSocket_ = transport;
              return transport;
            }));

    // The first peer-initiated bidirectional stream becomes the server-side
    // AsyncTransport under test.
    EXPECT_CALL(serverConnectionCB_, onNewBidirectionalStream(_))
        .WillOnce(Invoke([&](StreamId id) {
          serverAsyncWrapper_ =
              QuicStreamAsyncTransport::createWithExistingStream(
                  serverSocket_, id);
          serverAsyncWrapper_->setReadCB(&serverReadCB_);
        }));

    // Non-movable read path: every read is delivered into serverBuf_.
    EXPECT_CALL(serverReadCB_, isBufferMovable_())
        .WillRepeatedly(Return(false));
    EXPECT_CALL(serverReadCB_, getReadBuffer(_, _))
        .WillRepeatedly(Invoke([&](void** buf, size_t* len) {
          *buf = serverBuf_.data();
          *len = serverBuf_.size();
        }));
    // Echo the received bytes back with an "echo " prefix, then half-close.
    EXPECT_CALL(serverReadCB_, readDataAvailable_(_))
        .WillOnce(Invoke([&](auto len) {
          auto echoData = folly::IOBuf::copyBuffer("echo ");
          // Copy rather than wrap: the chain handed to writeChain() may still
          // be held by the transport after this callback returns, whereas
          // serverBuf_ is handed out again as the target of any later read.
          echoData->appendChain(
              folly::IOBuf::copyBuffer(serverBuf_.data(), len));
          serverAsyncWrapper_->writeChain(&serverWriteCB_, std::move(echoData));
          serverAsyncWrapper_->shutdownWrite();
        }));
    EXPECT_CALL(serverReadCB_, readEOF_()).WillOnce(Return());

    EXPECT_CALL(serverWriteCB_, writeSuccess_()).WillOnce(Return());

    server_ = QuicServer::createQuicServer();
    auto serverCtx = test::createServerCtx();
    server_->setFizzContext(serverCtx);

    server_->setQuicServerTransportFactory(std::move(serverTransportFactory));

    // Port 0 lets the OS choose a free port; the actual bound address is read
    // back once the server has finished initializing.
    folly::SocketAddress addr("::1", 0);
    server_->start(addr, 1);
    server_->waitUntilInitialized();
    serverAddr_ = server_->getAddress();
  }

  /// Starts the client event loop thread, connects to the server, and blocks
  /// for at most one second until the transport is ready and the client-side
  /// stream wrapper has been created.
  void createClient() {
    clientEvbThread_ = std::thread([&]() { clientEvb_.loopForever(); });

    // Once the transport is ready, open a new stream, wrap it, and unblock
    // createClient() through startPromise_.
    EXPECT_CALL(clientConnectionCB_, onTransportReady()).WillOnce(Invoke([&]() {
      clientAsyncWrapper_ =
          QuicStreamAsyncTransport::createWithNewStream(client_);
      ASSERT_TRUE(clientAsyncWrapper_);
      clientAsyncWrapper_->setReadCB(&clientReadCB_);
      startPromise_.setValue();
    }));

    // Non-movable read path: every read is delivered into clientBuf_.
    EXPECT_CALL(clientReadCB_, isBufferMovable_())
        .WillRepeatedly(Return(false));
    EXPECT_CALL(clientReadCB_, getReadBuffer(_, _))
        .WillRepeatedly(Invoke([&](void** buf, size_t* len) {
          *buf = clientBuf_.data();
          *len = clientBuf_.size();
        }));
    // Publish the first chunk read so the test body can wait on it.
    EXPECT_CALL(clientReadCB_, readDataAvailable_(_))
        .WillOnce(Invoke([&](auto len) {
          clientReadPromise_.setValue(
              std::string(reinterpret_cast<char*>(clientBuf_.data()), len));
        }));
    EXPECT_CALL(clientReadCB_, readEOF_()).WillOnce(Return());

    EXPECT_CALL(clientWriteCB_, writeSuccess_()).WillOnce(Return());

    // Install the promise before the client starts, so onTransportReady()
    // always fulfils the promise whose future is awaited below.
    auto [promise, future] = folly::makePromiseContract<folly::Unit>();
    startPromise_ = std::move(promise);

    // The client transport is created and started on the client event base.
    clientEvb_.runInEventBaseThreadAndWait([&]() {
      auto sock = std::make_unique<folly::AsyncUDPSocket>(&clientEvb_);
      auto fizzClientContext =
          FizzClientQuicHandshakeContext::Builder()
              .setCertificateVerifier(test::createTestCertificateVerifier())
              .build();
      client_ = std::make_shared<QuicClientTransport>(
          &clientEvb_, std::move(sock), std::move(fizzClientContext));
      client_->setHostname("echo.com");
      client_->addNewPeerAddress(serverAddr_);
      client_->start(&clientConnectionCB_);
    });

    std::move(future).get(1s);
  }

  /// Tears down the server and client. GoogleTest runs TearDown() even when
  /// SetUp() failed part-way, so each step is guarded against state that may
  /// never have been created.
  void TearDown() override {
    if (serverAsyncWrapper_) {
      // Destroy the server-side wrapper on its own event base thread rather
      // than on the test thread.
      serverAsyncWrapper_->getEventBase()->runInEventBaseThreadAndWait(
          [&]() { serverAsyncWrapper_.reset(); });
    }
    if (server_) {
      server_->shutdown();
      server_ = nullptr;
    }
    // If createClient() never ran, the client loop is not running. In that
    // case runInEventBaseThreadAndWait() would wait indefinitely and join()
    // on a non-joinable std::thread would throw, so skip client cleanup.
    if (clientEvbThread_.joinable()) {
      clientEvb_.runInEventBaseThreadAndWait([&] {
        clientAsyncWrapper_ = nullptr;
        client_ = nullptr;
      });
      clientEvb_.terminateLoopSoon();
      clientEvbThread_.join();
    }
  }

 protected:
  // --- server side ---
  std::shared_ptr<QuicServer> server_;
  // Actual bound server address, read back after startup.
  folly::SocketAddress serverAddr_;
  NiceMock<MockConnectionCallback> serverConnectionCB_;
  // The single server transport, recorded by the mock factory.
  std::shared_ptr<quic::QuicSocket> serverSocket_;
  QuicStreamAsyncTransport::UniquePtr serverAsyncWrapper_;
  folly::test::MockWriteCallback serverWriteCB_;
  folly::test::MockReadCallback serverReadCB_;
  // Read target handed out by serverReadCB_.getReadBuffer().
  std::array<uint8_t, 1024> serverBuf_{};

  // --- client side ---
  std::shared_ptr<QuicClientTransport> client_;
  folly::EventBase clientEvb_;
  // Runs clientEvb_.loopForever(); joined in TearDown().
  std::thread clientEvbThread_;
  NiceMock<MockConnectionCallback> clientConnectionCB_;
  QuicStreamAsyncTransport::UniquePtr clientAsyncWrapper_;
  // Fulfilled from onTransportReady(); awaited by createClient().
  folly::Promise<folly::Unit> startPromise_;
  folly::test::MockWriteCallback clientWriteCB_;
  folly::test::MockReadCallback clientReadCB_;
  // Read target handed out by clientReadCB_.getReadBuffer().
  std::array<uint8_t, 1024> clientBuf_{};
  // Fulfilled with the first chunk the client reads; installed by each test.
  folly::Promise<std::string> clientReadPromise_;
};

TEST_F(QuicStreamAsyncTransportTest, ReadWrite) {
  // Install the promise the client read callback fulfils before sending
  // anything, so the echoed reply always has somewhere to land.
  auto [readPromise, readFuture] = folly::makePromiseContract<std::string>();
  clientReadPromise_ = std::move(readPromise);

  // Send the request and half-close, both on the client event base thread.
  const std::string request = "yo yo!";
  clientEvb_.runInEventBaseThreadAndWait([&] {
    clientAsyncWrapper_->write(&clientWriteCB_, request.data(), request.size());
    clientAsyncWrapper_->shutdownWrite();
  });

  // The fixture's server echoes the payload back with an "echo " prefix.
  const auto reply = std::move(readFuture).get(1s);
  EXPECT_EQ(reply, "echo yo yo!");
}

} // namespace quic::test
